Synthetic models for posterior distributions¶
Marco Raveri (marco.raveri@unige.it), Cyrille Doux (doux@lpsc.in2p3.fr), Shivam Pandey (shivampcosmo@gmail.com)
In this notebook we show how to build normalizing flow synthetic models for posterior distributions, as in Raveri, Doux and Pandey (2024), arXiv:XXXX.XXXX.
Notebook setup:¶
# Show plots inline, and load main getdist plot module and samples class
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2
# import libraries:
import sys, os
sys.path.insert(0,os.path.realpath(os.path.join(os.getcwd(),'../..')))
from getdist import plots, MCSamples
from getdist.gaussian_mixtures import GaussianND
import getdist
getdist.chains.print_load_details = False
import scipy
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# tensorflow imports:
import tensorflow as tf
import tensorflow_probability as tfp
# import the tensiometer tools that we need:
import tensiometer
from tensiometer import utilities
from tensiometer import synthetic_probability
# getdist settings to ensure consistency of plots:
getdist_settings = {'ignore_rows': 0.0,
'smooth_scale_2D': 0.3,
'smooth_scale_1D': 0.3,
}
We start by building a random Gaussian mixture that we are going to use for tests:
# define the parameters of the problem:
dim = 6
num_gaussians = 3
num_samples = 10000
# we seed the random number generator to get reproducible results:
seed = 100
np.random.seed(seed)
# we define the range for the means and covariances:
mean_range = (-0.5, 0.5)
cov_scale = 0.4**2
# means and covs:
means = np.random.uniform(mean_range[0], mean_range[1], num_gaussians*dim).reshape(num_gaussians, dim)
weights = np.random.rand(num_gaussians)
weights = weights / np.sum(weights)
covs = [cov_scale*utilities.vector_to_PDM(np.random.rand(int(dim*(dim+1)/2))) for _ in range(num_gaussians)]
# initialize distribution:
distribution = tfp.distributions.Mixture(
cat=tfp.distributions.Categorical(probs=weights),
components=[
tfp.distributions.MultivariateNormalTriL(loc=_m, scale_tril=tf.linalg.cholesky(_c))
for _m, _c in zip(means, covs)
], name='Mixture')
# sample the distribution:
samples = distribution.sample(num_samples).numpy()
# calculate log posteriors:
logP = distribution.log_prob(samples).numpy()
# create MCSamples from the samples:
chain = MCSamples(samples=samples,
settings=getdist_settings,
loglikes=-logP,
name_tag='Mixture',
)
# we make a sanity check plot:
g = plots.get_subplot_plotter()
g.triangle_plot(chain, filled=True)
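As an aside on the construction above, `utilities.vector_to_PDM` maps an unconstrained vector to a positive-definite matrix. If you do not have tensiometer at hand, a random positive-definite covariance can be sketched directly with numpy; the `random_pdm` helper below is illustrative, not part of the library:

```python
import numpy as np

def random_pdm(dim, scale=1.0, rng=None):
    """Random symmetric positive-definite matrix built as scale * A A^T / dim.

    A A^T is positive semi-definite by construction and, for a continuous
    random A, positive definite with probability one; dividing by dim keeps
    the eigenvalues of order `scale` as the dimension grows.
    """
    rng = np.random.default_rng(rng)
    A = rng.standard_normal((dim, dim))
    return scale * (A @ A.T) / dim

cov = random_pdm(6, scale=0.4**2, rng=0)
assert np.all(np.linalg.eigvalsh(cov) > 0)  # all eigenvalues strictly positive
```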
Base example:¶
Train a normalizing flow on samples of a given distribution
We initialize and train the normalizing flow on samples of the distribution we have just defined:
kwargs = {
'feedback': 2,
'plot_every': 1000,
'pop_size': 1,
#'cache_dir': 'test', # set this to a directory to cache the results
#'root_name': 'test', # sets the name of the flow for the cache files
}
flow = tensiometer.synthetic_probability.flow_from_chain(chain,  # input chain
**kwargs)
* Initializing samples
- flow name: Mixture_flow
- precision: <dtype: 'float32'>
- flow parameters and ranges:
param1 : [-1.25138, 1.32089]
param2 : [-1.36877, 1.33678]
param3 : [-1.45129, 0.793965]
param4 : [-0.892745, 1.4007]
param5 : [-1.73913, 1.22002]
param6 : [-1.59422, 1.21589]
- periodic parameters: []
* Initializing fixed bijector
- using prior bijector: ranges
- rescaling samples
* Initializing trainable bijector
Building Autoregressive Flow
- # parameters : 6
- periodic parameters : None
- # transformations : 8
- hidden_units : [12, 12]
- transformation_type : affine
- autoregressive_type : masked
- permutations : True
- scale_roto_shift : False
- activation : <function asinh at 0x1554bf817b50>
* Initializing training dataset
- 9000/1000 training/test samples and uniform weights
* Initializing transformed distribution
* Initializing loss function
- using standard loss function
- trainable parameters : 3168
- maximum learning rate: 0.001
- minimum learning rate: 1e-06
* Training
Epoch 1/100
20/20 - 8s - loss: 8.5185 - val_loss: 8.4926 - lr: 0.0010 - 8s/epoch - 381ms/step
Epoch 2/100
20/20 - 0s - loss: 8.4812 - val_loss: 8.4618 - lr: 0.0010 - 253ms/epoch - 13ms/step
...
Epoch 99/100
20/20 - 0s - loss: 7.4783 - val_loss: 7.5455 - lr: 0.0010 - 267ms/epoch - 13ms/step
Epoch 100/100
20/20 - 0s - loss: 7.4740 - val_loss: 7.5429 - lr: 0.0010 - 241ms/epoch - 12ms/step
* Population optimizer:
- best model is number 1
- best loss function is 7.47
- best validation loss function is 7.54
- population losses [7.54]
# we can plot training summaries to make sure training went smoothly:
flow.training_plot()
# and we can print the training summary:
flow.print_training_summary()
loss         : 7.4740
val_loss     : 7.5429
lr           : 0.0010
chi2Z_ks     : 0.0325
chi2Z_ks_p   : 0.2368
loss_rate    : -0.0042
val_loss_rate: -0.0026
# we can triangle plot the flow to see how well it has learned the target distribution:
g = plots.get_subplot_plotter()
g.triangle_plot([chain, flow.MCSamples(20000)],
params=flow.param_names,
filled=True)
# this looks nice but not perfect, let's train for longer:
flow.feedback = 1
flow.train(epochs=300, verbose=-1); # verbose = -1 uses tqdm progress bar
# we can plot training summaries to make sure training went smoothly:
flow.training_plot()
If you train for long enough, you should start seeing the learning rate adapt to the non-improving (noisy) loss function.
This means that the flow is learning finer and finer features, and is a good indication that training is converging. If you push it further, at some point the flow will start overfitting and training will stop.
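The learning-rate behavior described above is a plateau-style schedule: when the validation loss stops improving for a few epochs, the rate is decayed. A minimal illustrative stand-in (the `adapt_lr` helper below is a sketch, not the actual callback tensiometer uses):

```python
def adapt_lr(val_losses, lr, factor=0.5, patience=3, min_lr=1e-6):
    """Reduce lr when the validation loss plateaus (illustrative sketch).

    If the best validation loss over the last `patience` epochs is no better
    than the best seen before them, decay the learning rate by `factor`,
    never going below `min_lr`.
    """
    if len(val_losses) > patience and min(val_losses[-patience:]) >= min(val_losses[:-patience]):
        return max(lr * factor, min_lr)
    return lr

# improving loss: learning rate is left alone
assert adapt_lr([1.0, 0.9, 0.8, 0.7, 0.6], 1e-3) == 1e-3
# plateaued loss: learning rate is halved
assert adapt_lr([1.0, 0.9, 0.95, 0.95, 0.95], 1e-3) == 5e-4
```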
Now let's look at what the marginal distributions look like:
# we can triangle plot the flow to see how well it has learned the target distribution:
g = plots.get_subplot_plotter()
g.triangle_plot([chain,
flow.MCSamples(20000) # this flow method returns a MCSamples object
],
params=flow.param_names,
filled=True)
We can use the trained flow to perform several operations. For example, let's compute log-posterior values for a new batch of flow samples:
samples = flow.MCSamples(20000)
logP = flow.log_probability(flow.cast(samples.samples)).numpy()
samples.addDerived(logP, name='logP', label='\\log P')
samples.updateBaseStatistics();
# now let's plot everything:
g = plots.get_subplot_plotter()
g.triangle_plot([samples, chain],
plot_3d_with_param='logP',
filled=False)
Here we can appreciate a beautiful display of a projection effect: the marginal distribution of $p_5$ peaks at a positive value, while the logP coloring clearly shows that the peak of the full distribution is the negative mode.
If you are interested in understanding these types of effects systematically, check the corresponding tensiometer tutorial!
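The effect is easy to reproduce with a toy two-dimensional mixture: a narrow, low-weight mode can carry the global density maximum while a broad, heavier mode dominates the 1D marginal. A small scipy sketch (all numbers below are made up for illustration):

```python
import numpy as np
from scipy import stats

# Two-component 2D Gaussian mixture (weights, means, per-axis sigmas).
weights = [0.2, 0.8]
means   = [np.array([-1.0, 0.0]), np.array([1.0, 0.0])]
sigmas  = [np.array([0.2, 0.05]), np.array([0.5, 0.5])]

def full_density(xy):
    """Full 2D mixture density at point xy."""
    return sum(w * stats.multivariate_normal.pdf(xy, mean=m, cov=np.diag(s**2))
               for w, m, s in zip(weights, means, sigmas))

def marginal_x(x):
    """1D marginal density of the first coordinate."""
    return sum(w * stats.norm.pdf(x, loc=m[0], scale=s[0])
               for w, m, s in zip(weights, means, sigmas))

# The narrow component makes the full 2D density peak at x = -1 ...
assert full_density([-1.0, 0.0]) > full_density([1.0, 0.0])
# ... while the heavier, broad component makes the marginal peak at x = +1.
assert marginal_x(1.0) > marginal_x(-1.0)
```

Looking only at the marginal, one would conclude the preferred value of x is positive, even though the best-fit point of the full distribution sits at the negative mode.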
Average flow example:¶
A more advanced flow model consists of training several flows and combining them into a weighted mixture of normalizing flows.
This model reduces the variance of the flow in regions where samples are scarce (since different flow models will hallucinate differently there).
Let's try averaging 5 flow models (note that we could do this in parallel with MPI on bigger machines):
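Conceptually, the average flow evaluates a weighted mixture density, $\log p(x) = \log \sum_i w_i\, p_i(x)$, computed in log space for numerical stability via logsumexp. A minimal sketch (the `mixture_log_prob` helper is illustrative, not tensiometer's actual implementation):

```python
import numpy as np
from scipy.special import logsumexp

def mixture_log_prob(log_probs, log_weights):
    """Weighted-mixture log-density from per-flow log-densities.

    log_probs: (num_flows, num_samples) array of per-flow log p_i(x)
    log_weights: (num_flows,) array of log w_i, with sum_i w_i = 1
    """
    return logsumexp(log_probs + log_weights[:, None], axis=0)

# toy check: three "flows" that all assign density 0.1 everywhere;
# with normalized weights, the mixture must return the same density.
log_w = np.log(np.array([0.2, 0.3, 0.5]))
lp = np.tile(np.log(0.1), (3, 4))
assert np.allclose(mixture_log_prob(lp, log_w), np.log(0.1))
```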
kwargs = {
'feedback': 1,
'verbose': -1,
'plot_every': 1000,
'pop_size': 1,
'num_flows': 5,
'epochs': 400,
}
average_flow = tensiometer.synthetic_probability.average_flow_from_chain(chain,  # input chain
**kwargs)
Warning: MPI is incompatible with no cache. Disabling MPI. Training flow 0 * Initializing samples * Initializing fixed bijector * Initializing trainable bijector * Initializing training dataset * Initializing transformed distribution * Initializing loss function * Training
Training flow 1 * Initializing samples * Initializing fixed bijector * Initializing trainable bijector * Initializing training dataset * Initializing transformed distribution * Initializing loss function * Training
Training flow 2 * Initializing samples * Initializing fixed bijector * Initializing trainable bijector * Initializing training dataset * Initializing transformed distribution * Initializing loss function * Training
Training flow 3 * Initializing samples * Initializing fixed bijector * Initializing trainable bijector * Initializing training dataset * Initializing transformed distribution * Initializing loss function * Training
Training flow 4 * Initializing samples * Initializing fixed bijector * Initializing trainable bijector * Initializing training dataset * Initializing transformed distribution * Initializing loss function * Training
# most methods are implemented for the average flow as well:
average_flow.training_plot()
# and we can print the training summary, which in this case contains more info:
average_flow.print_training_summary()
Number of flows: 5
Flow weights : [0.21 0.21 0.19 0.17 0.21]
loss         : [7.29 7.29 7.34 7.36 7.27]
val_loss     : [7.4  7.37 7.47 7.58 7.36]
lr           : [3.16e-05 1.00e-04 1.00e-06 1.00e-06 3.16e-04]
chi2Z_ks     : [0.04 0.03 0.04 0.02 0.02]
chi2Z_ks_p   : [0.1  0.44 0.1  0.64 0.76]
loss_rate    : [-2.86e-05  3.40e-04 -9.54e-07  2.86e-06 -6.83e-04]
val_loss_rate: [ 4.29e-05 -1.93e-03 -8.49e-05  4.63e-05 -4.09e-03]
avg_samples = average_flow.MCSamples(20000)
avg_samples.name_tag = 'Average Flow'
temp_samples = [_f.MCSamples(20000) for _f in average_flow.flows]
for i, _s in enumerate(temp_samples):
_s.name_tag = _s.name_tag + f'_{i}'
# let's plot the flows:
g = plots.get_subplot_plotter()
g.triangle_plot([chain, avg_samples] + temp_samples,
filled=False)
WARNING:tensorflow:5 out of the last 8 calls to <function FlowCallback.log_probability at 0x1d2b024c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:5 out of the last 7 calls to <function FlowCallback.sample at 0x1d3d5d700> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:6 out of the last 9 calls to <function FlowCallback.log_probability at 0x1d3d0dca0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:6 out of the last 8 calls to <function FlowCallback.sample at 0x1d4e65820> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
logP = average_flow.log_probability(average_flow.cast(avg_samples.samples)).numpy()
avg_samples.addDerived(logP, name='logP', label='\\log P')
avg_samples.updateBaseStatistics();
# now let's plot everything:
g = plots.get_subplot_plotter()
g.triangle_plot([avg_samples, chain],
plot_3d_with_param='logP',
filled=False)
Real world application: joint parameter estimation¶
In this example we perform a flow-based analysis of a joint posterior.
The idea is that we have posterior samples from two independent experiments: we learn the two posteriors and then combine them to form the joint posterior. Note that we are assuming - as is true in this example - that the prior is the same for both experiments and flat (so that we are not double-counting the prior).
This procedure was used, for example, in Gatti, Campailla et al (2024), arXiv:2405.10881.
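Before coding it up, it is worth seeing why summing the two flow log-posteriors is the right operation: for independent experiments with identical flat priors, the joint posterior is proportional to the product of the individual posteriors. In the Gaussian case this product is again a Gaussian whose precisions (inverse covariances) add. A minimal numpy sketch (the `combine_gaussians` helper is illustrative, not part of tensiometer):

```python
import numpy as np

def combine_gaussians(m1, C1, m2, C2):
    """Product of two Gaussian densities, renormalized to a Gaussian.

    Precisions add: C12^-1 = C1^-1 + C2^-1, and the combined mean is the
    precision-weighted average of the two means.
    """
    P1, P2 = np.linalg.inv(C1), np.linalg.inv(C2)
    C12 = np.linalg.inv(P1 + P2)
    m12 = C12 @ (P1 @ m1 + P2 @ m2)
    return m12, C12

# equal-precision 1D example: combined mean is the average, variance halves
m12, C12 = combine_gaussians(np.array([0.0]), np.array([[1.0]]),
                             np.array([2.0]), np.array([[1.0]]))
assert np.allclose(m12, [1.0]) and np.allclose(C12, [[0.5]])
```

The flow-based combination below does the same thing non-parametrically, by summing the learned log-densities point by point.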
# we start by loading up the posteriors:
# load the samples (no burn-in removal needed since the example chains have already been cleaned):
chains_dir = './../../test_chains/'
# the Planck 2018 TTTEEE chain:
chain_1 = getdist.mcsamples.loadMCSamples(file_root=chains_dir+'Planck18TTTEEE', no_cache=True, settings=getdist_settings)
# the DES Y1 3x2 chain:
chain_2 = getdist.mcsamples.loadMCSamples(file_root=chains_dir+'DES', no_cache=True, settings=getdist_settings)
# the joint chain:
chain_12 = getdist.mcsamples.loadMCSamples(file_root=chains_dir+'Planck18TTTEEE_DES', no_cache=True, settings=getdist_settings)
# let's add omegab as a derived parameter:
for _ch in [chain_1, chain_2, chain_12]:
_p = _ch.getParams()
_h = _p.H0 / 100.
_ch.addDerived(_p.omegabh2 / _h**2, name='omegab', label='\\Omega_b')
_ch.updateBaseStatistics()
# we define the parameters of the problem:
param_names = ['H0', 'omegam', 'sigma8', 'ns', 'omegab']
# and then do a sanity check plot:
g = plots.get_subplot_plotter()
g.triangle_plot([chain_1, chain_2, chain_12], params=param_names, filled=True)
# we then train the flows on the base parameters that we want to combine (note that for this exercise we should include all shared parameters):
kwargs = {
'feedback': 1,
'verbose': -1,
'plot_every': 1000,
'pop_size': 1,
'num_flows': 3,
'epochs': 400,
}
# actual flow training:
flow_1 = tensiometer.synthetic_probability.average_flow_from_chain(chain_1, param_names=param_names, **kwargs)
flow_2 = tensiometer.synthetic_probability.average_flow_from_chain(chain_2, param_names=param_names, **kwargs)
flow_12 = tensiometer.synthetic_probability.average_flow_from_chain(chain_12, param_names=param_names, **kwargs)
# plot to make sure training went well:
flow_1.training_plot()
flow_2.training_plot()
flow_12.training_plot()
Warning: MPI is incompatible with no cache. Disabling MPI. Training flow 0 * Initializing samples * Initializing fixed bijector * Initializing trainable bijector * Initializing training dataset * Initializing transformed distribution * Initializing loss function * Training
Training flow 1 * Initializing samples * Initializing fixed bijector * Initializing trainable bijector * Initializing training dataset * Initializing transformed distribution * Initializing loss function * Training
Training flow 2 * Initializing samples * Initializing fixed bijector * Initializing trainable bijector * Initializing training dataset * Initializing transformed distribution * Initializing loss function * Training
Warning: MPI is incompatible with no cache. Disabling MPI. Training flow 0 * Initializing samples * Initializing fixed bijector * Initializing trainable bijector * Initializing training dataset * Initializing transformed distribution * Initializing loss function * Training
Training flow 1 * Initializing samples * Initializing fixed bijector * Initializing trainable bijector * Initializing training dataset * Initializing transformed distribution * Initializing loss function * Training
Training flow 2 * Initializing samples * Initializing fixed bijector * Initializing trainable bijector * Initializing training dataset * Initializing transformed distribution * Initializing loss function * Training
Warning: MPI is incompatible with no cache. Disabling MPI. Training flow 0 * Initializing samples * Initializing fixed bijector * Initializing trainable bijector * Initializing training dataset * Initializing transformed distribution * Initializing loss function * Training
Training flow 1 * Initializing samples * Initializing fixed bijector * Initializing trainable bijector * Initializing training dataset * Initializing transformed distribution * Initializing loss function * Training
Training flow 2 * Initializing samples * Initializing fixed bijector * Initializing trainable bijector * Initializing training dataset * Initializing transformed distribution * Initializing loss function * Training
# sanity check triangle plot:
g = plots.get_subplot_plotter()
g.triangle_plot([chain_1, flow_1.MCSamples(20000, settings=getdist_settings),
chain_2, flow_2.MCSamples(20000, settings=getdist_settings),
chain_12, flow_12.MCSamples(20000, settings=getdist_settings),
],
params=param_names,
filled=False)
# we log-scale the y axis of the 1D marginal panels so that we can appreciate the accuracy of the flow in the tails:
for i in range(len(param_names)):
_ax = g.subplots[i, i]
_ax.set_yscale('log')
_ax.set_ylim([1.e-5, 1.0])
_ax.set_ylabel('$\\log P$')
_ax.tick_params(axis='y', which='both', labelright='on')
_ax.yaxis.set_label_position("right")
# now we can define the joint posterior:
def joint_log_posterior(H0, omegam, sigma8, ns, omegab):
params = [H0, omegam, sigma8, ns, omegab]
return [flow_1.log_probability(flow_1.cast(params)).numpy() + flow_2.log_probability(flow_2.cast(params)).numpy()]
# and sample it:
from cobaya.run import run
from getdist.mcsamples import MCSamplesFromCobaya
parameters = {}
for key in param_names:
parameters[key] = {"prior": {"min": 1.01*max(flow_1.parameter_ranges[key][0], flow_2.parameter_ranges[key][0]),
"max": 0.99*min(flow_1.parameter_ranges[key][1], flow_2.parameter_ranges[key][1])},
"latex": flow_1.param_labels[flow_1.param_names.index(key)]}
info = {
"likelihood": {"joint_log_posterior": joint_log_posterior},
"params": parameters,
}
# MCMC sample:
# we need a reasonably good initial proposal and starting point; we get them from one of the flows:
flow_1_samples = flow_1.sample(10000)
flow_1_logPs = flow_1.log_probability(flow_1_samples).numpy()
flow_1_maxP_sample = flow_1_samples[np.argmax(flow_1_logPs)].numpy()
# we need a good starting point, otherwise sampling will take a long time:
for _i, _k in enumerate(parameters.keys()):
info['params'][_k]['ref'] = flow_1_maxP_sample[_i]
info["sampler"] = {"mcmc":
{'covmat': np.cov(flow_1_samples.numpy().T),
'covmat_params': param_names,
'max_tries': np.inf,
'Rminus1_stop': 0.01,
'learn_proposal_Rminus1_max': 30.,
'learn_proposal_Rminus1_max_early': 30.,
'measure_speeds': False,
'Rminus1_single_split': 10,
}}
info['debug'] = 100  # hack: a very high debug (logging) level suppresses the otherwise very verbose output
updated_info, sampler = run(info)
joint_chain = MCSamplesFromCobaya(updated_info, sampler.products()["sample"], ignore_rows=0.3, settings=getdist_settings)
WARNING:tensorflow:5 out of the last 5 calls to <function average_flow.sample at 0x207d5f3a0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
## Nested sampling sample:
#_dim = len(flow_1.param_names)
#
#info["sampler"] = {"polychord": {'nlive': 50*_dim,
# 'measure_speeds': False,
# 'num_repeats': 2*_dim,
# 'nprior': 10*25*_dim,
# 'do_clustering': True,
# 'precision_criterion': 0.01,
# 'boost_posterior': 10,
# 'feedback': 0,
# },
# }
#info['debug'] = 100  # hack: a very high debug (logging) level suppresses the otherwise very verbose output
#updated_info, sampler = run(info)
#joint_chain = MCSamplesFromCobaya(updated_info, sampler.products()["sample"], settings=getdist_settings)
joint_chain.name_tag = 'Flow Joint'
chain_12.name_tag = 'Real Joint (Planck + DES)'
# sanity check triangle plot:
g = plots.get_subplot_plotter()
g.triangle_plot([joint_chain, chain_12],
params=param_names,
filled=False)
As we can see this works fairly well, given that the two experiments are in some tension, i.e. their posteriors do not overlap significantly.
Make sure you check the consistency of the experiments before combining them, so that the joint flow posterior samples a well-trained region of the flows.
You can check the tension example notebook in this documentation for how to compute the tension between two experiments.
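As a crude stand-in for the full parameter-difference analysis, a Gaussian estimate of the tension can be computed from the means and covariances of the two posteriors: the difference of means in units of the summed covariance. All numbers below are illustrative:

```python
import numpy as np
from scipy.stats import chi2, norm

# illustrative means and covariances for the two posteriors:
mean_1, cov_1 = np.array([0.0, 0.0]), np.eye(2)
mean_2, cov_2 = np.array([2.0, 1.0]), 0.5 * np.eye(2)

# Gaussian tension: difference of means in units of the summed covariance
diff = mean_1 - mean_2
cov = cov_1 + cov_2
chi2_val = diff @ np.linalg.solve(cov, diff)
p = chi2.sf(chi2_val, df=len(diff))  # probability to exceed
n_sigma = norm.isf(p / 2.0)          # two-tailed Gaussian sigmas
print(f'{n_sigma:.2f} sigma')        # roughly 1.3 sigma for these numbers
```

This Gaussian estimate is only reliable for nearly-Gaussian posteriors; for the general case use the parameter-difference methods in tensiometer.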
Advanced Topic: accurate likelihood values¶
For some applications we need to push the local accuracy of the flow model. In this case we need to provide exact probability values (up to a normalization constant) for the training set.
These are used to build an additional term of the loss function that rewards local accuracy of the probability values; this second term is the estimated evidence error. By default the code adaptively mixes the two loss terms to find an optimal solution.
As a downside, we can only train a flow that preserves all the parameters of the distribution, i.e. we cannot train on marginalized parameters (as we did in the previous examples).
For more details see Raveri, Doux and Pandey (2024), the reference paper for this notebook.
ev, eer = flow.evidence()
print(f'log(Z) = {ev} +- {eer}')
log(Z) = 0.11070668697357178 +- 0.7633598446846008
We can see that the value is close to what it should be (zero, since the original distribution is normalized), but the estimated error is still fairly high.
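Concretely, since the true value is log(Z) = 0 for a normalized distribution, we can express the deviation of the estimate in units of its estimated error (values copied from the output above):

```python
# deviation of the evidence estimate from the truth (log Z = 0),
# in units of its estimated error; values copied from the output above:
ev, eer = 0.11070668697357178, 0.7633598446846008
n_sigma = abs(ev - 0.0) / eer
print(f'{n_sigma:.2f} sigma')  # well below 1 sigma, but the error itself is large
```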
Since we have (normalized) log P values we can check the local reliability of the normalizing flow:
validation_flow_log10_P = flow.log_probability(flow.cast(chain.samples[flow.test_idx, :])).numpy()/np.log(10.)
validation_samples_log10_P = -chain.loglikes[flow.test_idx]/np.log(10.) # notice the minus sign due to the definition of logP in getdist
training_flow_log10_P = flow.log_probability(flow.cast(chain.samples[flow.training_idx, :])).numpy()/np.log(10.)
training_samples_log10_P = -chain.loglikes[flow.training_idx]/np.log(10.) # notice the minus sign due to the definition of logP in getdist
# do the plot:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
ax1.scatter(training_samples_log10_P - np.amax(training_samples_log10_P), training_flow_log10_P - training_samples_log10_P, s=0.1, label='training')
ax1.scatter(validation_samples_log10_P - np.amax(validation_samples_log10_P), validation_flow_log10_P - validation_samples_log10_P, s=0.5, label='validation')
ax1.legend()
ax1.axhline(0, color='k', linestyle='--')
ax1.set_xlabel('$\\log_{10}(P_{\\rm true}/P_{\\rm max})$')
ax1.set_ylabel('$\\log_{10}(P_{\\rm flow}) - \\log_{10}(P_{\\rm true})$')
ax1.set_ylim(-1.0, 1.0)
ax2.hist(training_flow_log10_P - training_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='training')
ax2.hist(validation_flow_log10_P - validation_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='validation')
ax2.legend()
ax2.axvline(0, color='k', linestyle='--')
ax2.set_xlabel('$\\log_{10}(P_{\\rm flow}) - \\log_{10}(P_{\\rm true})$')
ax2.set_xlim([-1.0, 1.0])
plt.tight_layout()
plt.show()
We can clearly see that the local accuracy of the flow in full dimension is not high: as we move to the tails we easily get large errors. The variance of this plot sets the estimated error on the evidence, which is rather large and dominated by the outliers in the tails.
Averaging flows usually improves the situation, in particular on the validation sample.
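The connection between the scatter in these panels and the evidence error can be sketched as follows (synthetic residuals here, so the snippet is self-contained; with the real arrays one would use `validation_flow_log10_P - validation_samples_log10_P`):

```python
import numpy as np

# synthetic stand-in for the residuals plotted above
# (log10 P_flow - log10 P_true over the validation points):
rng = np.random.default_rng(0)
residuals_log10 = rng.normal(0.0, 0.2, size=5000)

# the evidence error is driven by the scatter of these residuals,
# converted back to natural log:
scatter_ln = np.std(residuals_log10) * np.log(10.0)
print(f'{scatter_ln:.2f}')  # ~0.46 for this synthetic scatter
```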
ev, eer = average_flow.evidence()
print(f'log(Z) = {ev} +- {eer}')
log(Z) = 0.06764783710241318 +- 0.4927523136138916
validation_flow_log10_P = average_flow.log_probability(average_flow.cast(chain.samples[average_flow.test_idx, :])).numpy()/np.log(10.)
validation_samples_log10_P = -chain.loglikes[average_flow.test_idx]/np.log(10.) # notice the minus sign due to the definition of logP in getdist
training_flow_log10_P = average_flow.log_probability(average_flow.cast(chain.samples[average_flow.training_idx, :])).numpy()/np.log(10.)
training_samples_log10_P = -chain.loglikes[average_flow.training_idx]/np.log(10.) # notice the minus sign due to the definition of logP in getdist
# do the plot:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
ax1.scatter(training_samples_log10_P - np.amax(training_samples_log10_P), training_flow_log10_P - training_samples_log10_P, s=0.1, label='training')
ax1.scatter(validation_samples_log10_P - np.amax(validation_samples_log10_P), validation_flow_log10_P - validation_samples_log10_P, s=0.5, label='validation')
ax1.legend()
ax1.axhline(0, color='k', linestyle='--')
ax1.set_xlabel('$\\log_{10}(P_{\\rm true}/P_{\\rm max})$')
ax1.set_ylabel('$\\log_{10}(P_{\\rm flow}) - \\log_{10}(P_{\\rm true})$')
ax1.set_ylim(-1.0, 1.0)
ax2.hist(training_flow_log10_P - training_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='training')
ax2.hist(validation_flow_log10_P - validation_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='validation')
ax2.legend()
ax2.axvline(0, color='k', linestyle='--')
ax2.set_xlabel('$\\log_{10}(P_{\\rm flow}) - \\log_{10}(P_{\\rm true})$')
ax2.set_xlim([-1.0, 1.0])
plt.tight_layout()
plt.show()
This looks significantly better, and in fact the error on the evidence estimate is roughly halved.
If we want to do better we need to train with the evidence error loss, as discussed in the reference paper for this example notebook.
kwargs = {
'feedback': 1,
'verbose': -1,
'plot_every': 1000,
'pop_size': 1,
'num_flows': 1,
'epochs': 400,
'loss_mode': 'softadapt',
}
average_flow_2 = tensiometer.synthetic_probability.average_flow_from_chain(chain, # parameter difference chain
**kwargs)
Warning: MPI is incompatible with no cache. Disabling MPI.
Training flow 0
* Initializing samples
* Initializing fixed bijector
* Initializing trainable bijector
* Initializing training dataset
* Initializing transformed distribution
* Initializing loss function
* Training
average_flow_2.training_plot()
As we can see the training plots are substantially more complicated, since we are now monitoring several additional quantities.
ev, eer = average_flow_2.evidence()
print(f'log(Z) = {ev} +- {eer}')
log(Z) = 0.10600437223911285 +- 0.44058406352996826
validation_flow_log10_P = average_flow_2.log_probability(average_flow_2.cast(chain.samples[average_flow_2.test_idx, :])).numpy()/np.log(10.)
validation_samples_log10_P = -chain.loglikes[average_flow_2.test_idx]/np.log(10.) # notice the minus sign due to the definition of logP in getdist
training_flow_log10_P = average_flow_2.log_probability(average_flow_2.cast(chain.samples[average_flow_2.training_idx, :])).numpy()/np.log(10.)
training_samples_log10_P = -chain.loglikes[average_flow_2.training_idx]/np.log(10.) # notice the minus sign due to the definition of logP in getdist
# do the plot:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
ax1.scatter(training_samples_log10_P - np.amax(training_samples_log10_P), training_flow_log10_P - training_samples_log10_P, s=0.1, label='training')
ax1.scatter(validation_samples_log10_P - np.amax(validation_samples_log10_P), validation_flow_log10_P - validation_samples_log10_P, s=0.5, label='validation')
ax1.legend()
ax1.axhline(0, color='k', linestyle='--')
ax1.set_xlabel('$\\log_{10}(P_{\\rm true}/P_{\\rm max})$')
ax1.set_ylabel('$\\log_{10}(P_{\\rm flow}) - \\log_{10}(P_{\\rm true})$')
ax1.set_ylim(-1.0, 1.0)
ax2.hist(training_flow_log10_P - training_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='training')
ax2.hist(validation_flow_log10_P - validation_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='validation')
ax2.legend()
ax2.axvline(0, color='k', linestyle='--')
ax2.set_xlabel('$\\log_{10}(P_{\\rm flow}) - \\log_{10}(P_{\\rm true})$')
ax2.set_xlim([-1.0, 1.0])
plt.tight_layout()
plt.show()
As we can see this achieves performance very close to that of averaging flows. Combining the two strategies gives the best results.
Advanced Topic: Spline Flows¶
When more flexibility in the normalizing flow model is needed we provide an implementation of neural spline flows as discussed in Durkan et al (2019), arXiv:1906.04032.
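The building block of a neural spline flow is a monotone rational-quadratic map on each bin of an interval, with bin widths, heights and knot slopes predicted by a neural network. A self-contained numpy sketch of the single-bin transform (following eq. 4 of Durkan et al. 2019; the bin and slope values below are illustrative):

```python
import numpy as np

def rq_spline_bin(x, x0, x1, y0, y1, d0, d1):
    """Monotone rational-quadratic map of one bin [x0, x1] onto [y0, y1],
    with derivatives d0, d1 at the knots (eq. 4 of Durkan et al. 2019)."""
    w, h = x1 - x0, y1 - y0
    s = h / w                        # average slope of the bin
    xi = (x - x0) / w                # position within the bin, in [0, 1]
    num = h * (s * xi**2 + d0 * xi * (1.0 - xi))
    den = s + (d0 + d1 - 2.0 * s) * xi * (1.0 - xi)
    return y0 + num / den

# illustrative bin: map [0, 1] onto [0, 2] with knot slopes 0.5 and 3
x = np.linspace(0.0, 1.0, 101)
y = rq_spline_bin(x, 0.0, 1.0, 0.0, 2.0, 0.5, 3.0)
print(y[0], y[-1])               # knots map to knots: 0.0 2.0
print(np.all(np.diff(y) > 0))    # monotone for positive slopes: True
```

Monotonicity of each bin (guaranteed for positive widths, heights and slopes) is what makes the full spline bijector invertible.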
kwargs = {
# flow settings:
'pop_size': 1,
'num_flows': 1,
'epochs': 400,
'transformation_type': 'spline',
'autoregressive_type': 'masked',
# feedback flags:
'feedback': 1,
'verbose': -1,
'plot_every': 1000,
}
spline_flow = tensiometer.synthetic_probability.flow_from_chain(chain, # parameter difference chain
**kwargs)
* Initializing samples
* Initializing fixed bijector
* Initializing trainable bijector
WARNING: range_max should be larger than the maximum range of the data and is being adjusted.
range_max: 5.0
max range: 11.324654579162598
new range_max: 12.324655
* Initializing training dataset
* Initializing transformed distribution
* Initializing loss function
* Training
# we can plot training summaries to make sure training went smoothly:
spline_flow.training_plot()
# we can triangle plot the flow to see how well it has learned the target distribution:
g = plots.get_subplot_plotter()
g.triangle_plot([chain,
spline_flow.MCSamples(20000) # this flow method returns a MCSamples object
],
params=spline_flow.param_names,
filled=True)
samples = spline_flow.MCSamples(20000)
logP = spline_flow.log_probability(spline_flow.cast(samples.samples)).numpy()
samples.addDerived(logP, name='logP', label='\\log P')
samples.updateBaseStatistics();
# now let's plot everything:
g = plots.get_subplot_plotter()
g.triangle_plot([samples, chain],
plot_3d_with_param='logP',
filled=False)
validation_flow_log10_P = spline_flow.log_probability(spline_flow.cast(chain.samples[spline_flow.test_idx, :])).numpy()/np.log(10.)
validation_samples_log10_P = -chain.loglikes[spline_flow.test_idx]/np.log(10.) # notice the minus sign due to the definition of logP in getdist
training_flow_log10_P = spline_flow.log_probability(spline_flow.cast(chain.samples[spline_flow.training_idx, :])).numpy()/np.log(10.)
training_samples_log10_P = -chain.loglikes[spline_flow.training_idx]/np.log(10.) # notice the minus sign due to the definition of logP in getdist
ev, eer = spline_flow.evidence()
print(f'log(Z) = {ev} +- {eer}')
# do the plot:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
ax1.scatter(training_samples_log10_P - np.amax(training_samples_log10_P), training_flow_log10_P - training_samples_log10_P, s=0.1, label='training')
ax1.scatter(validation_samples_log10_P - np.amax(validation_samples_log10_P), validation_flow_log10_P - validation_samples_log10_P, s=0.5, label='validation')
ax1.legend()
ax1.axhline(0, color='k', linestyle='--')
ax1.set_xlabel('$\\log_{10}(P_{\\rm true}/P_{\\rm max})$')
ax1.set_ylabel('$\\log_{10}(P_{\\rm flow}) - \\log_{10}(P_{\\rm true})$')
ax1.set_ylim(-1.0, 1.0)
ax2.hist(training_flow_log10_P - training_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='training')
ax2.hist(validation_flow_log10_P - validation_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='validation')
ax2.legend()
ax2.axvline(0, color='k', linestyle='--')
ax2.set_xlabel('$\\log_{10}(P_{\\rm flow}) - \\log_{10}(P_{\\rm true})$')
ax2.set_xlim([-1.0, 1.0])
plt.tight_layout()
plt.show()
log(Z) = 0.13358823955059052 +- 0.9221242070198059
We can check what happens across the bijector layers:
from tensiometer.synthetic_probability import flow_utilities as flow_utils
training_samples_spaces, validation_samples_spaces = \
flow_utils.get_samples_bijectors(spline_flow,
feedback=True)
for i, _s in enumerate(training_samples_spaces):
print('* ', _s.name_tag)
g = plots.get_subplot_plotter()
g.triangle_plot([
training_samples_spaces[i],
validation_samples_spaces[i]],
filled=True,
)
plt.show()
0 - bijector name: permute
1 - bijector name: spline_flow
2 - bijector name: permute
3 - bijector name: spline_flow
4 - bijector name: permute
5 - bijector name: spline_flow
6 - bijector name: permute
7 - bijector name: spline_flow
8 - bijector name: permute
9 - bijector name: spline_flow
10 - bijector name: permute
11 - bijector name: spline_flow
12 - bijector name: permute
13 - bijector name: spline_flow
14 - bijector name: permute
15 - bijector name: spline_flow
* training_space
* 0_after_permute
* 1_after_spline_flow
* 2_after_permute
* 3_after_spline_flow
* 4_after_permute
* 5_after_spline_flow
* 6_after_permute
* 7_after_spline_flow
* 8_after_permute
* 9_after_spline_flow
* 10_after_permute
* 11_after_spline_flow
* 12_after_permute
* 13_after_spline_flow
* 14_after_permute
* 15_after_spline_flow